/*
 * Copyright 2012 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
package io.netty.channel.socket.nio;

import io.netty.buffer.ByteBuf;
import io.netty.channel.AddressedEnvelope;
import io.netty.channel.Channel;
import io.netty.channel.ChannelException;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelMetadata;
import io.netty.channel.ChannelOption;
import io.netty.channel.ChannelOutboundBuffer;
import io.netty.channel.ChannelPromise;
import io.netty.channel.DefaultAddressedEnvelope;
import io.netty.channel.RecvByteBufAllocator;
import io.netty.channel.nio.AbstractNioMessageChannel;
import io.netty.channel.socket.DatagramChannelConfig;
import io.netty.channel.socket.DatagramPacket;
import io.netty.channel.socket.InternetProtocolFamily;
import io.netty.util.internal.SocketUtils;
import io.netty.util.internal.PlatformDependent;
import io.netty.util.internal.StringUtil;

import java.io.IOException;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.NetworkInterface;
import java.net.SocketAddress;
import java.net.SocketException;
import java.nio.ByteBuffer;
import java.nio.channels.DatagramChannel;
import java.nio.channels.MembershipKey;
import java.nio.channels.SelectionKey;
import java.nio.channels.spi.SelectorProvider;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;

/**
 * An NIO datagram {@link Channel} that sends and receives an
 * {@link AddressedEnvelope AddressedEnvelope&lt;ByteBuf, SocketAddress&gt;}.
 *
 * @see AddressedEnvelope
 * @see DatagramPacket
 */
public final class NioDatagramChannel extends AbstractNioMessageChannel
        implements io.netty.channel.socket.DatagramChannel
{

    // Metadata instance shared by every channel of this type; returned from metadata().
    // NOTE(review): the 'true' argument presumably advertises disconnect support - confirm
    // against ChannelMetadata.
    private static final ChannelMetadata METADATA = new ChannelMetadata(true);

    // Provider used by the constructors that are not given an explicit SelectorProvider.
    private static final SelectorProvider DEFAULT_SELECTOR_PROVIDER = SelectorProvider
            .provider();

    // Suffix appended to the "unsupported message type" error raised by
    // filterOutboundMessage(); lists the message types that method accepts.
    private static final String EXPECTED_TYPES = " (expected: "
            + StringUtil.simpleClassName(DatagramPacket.class) + ", "
            + StringUtil.simpleClassName(AddressedEnvelope.class) + '<'
            + StringUtil.simpleClassName(ByteBuf.class) + ", "
            + StringUtil.simpleClassName(SocketAddress.class) + ">, "
            + StringUtil.simpleClassName(ByteBuf.class) + ')';

    // Configuration created in the constructor and exposed through config().
    private final DatagramChannelConfig config;

    // Multicast membership keys grouped by multicast group address. Created lazily by
    // joinGroup(...) and only touched inside synchronized (this) blocks in
    // joinGroup(...), leaveGroup(...) and block(...).
    private Map<InetAddress, List<MembershipKey>> memberships;

    /**
     * Opens a new JDK {@link DatagramChannel} through the given provider.
     *
     * @param provider the {@link SelectorProvider} used to open the channel
     * @return the newly opened channel
     * @throws ChannelException if the provider fails with an {@link IOException}
     */
    private static DatagramChannel newSocket(SelectorProvider provider)
    {
        // The channel is opened through the provider directly instead of
        // DatagramChannel.open(), which would go through
        // SelectorProvider#provider() on every call.
        // See https://github.com/netty/netty/issues/2308
        final DatagramChannel opened;
        try
        {
            opened = provider.openDatagramChannel();
        }
        catch (IOException e)
        {
            throw new ChannelException("Failed to open a socket.", e);
        }
        return opened;
    }

    /**
     * Opens a new JDK {@link DatagramChannel} for a specific protocol family.
     *
     * @param provider the {@link SelectorProvider} used to open the channel
     * @param ipFamily the requested family, or {@code null} to fall back to
     *        {@link #newSocket(SelectorProvider)}
     * @return the newly opened channel
     * @throws UnsupportedOperationException if a family is given on a Java
     *         version older than 7
     * @throws ChannelException if the provider fails with an {@link IOException}
     */
    private static DatagramChannel newSocket(SelectorProvider provider,
            InternetProtocolFamily ipFamily)
    {
        if (ipFamily != null)
        {
            // Opening a channel for an explicit family is guarded by the
            // Java version check.
            checkJavaVersion();
            try
            {
                return provider.openDatagramChannel(
                        ProtocolFamilyConverter.convert(ipFamily));
            }
            catch (IOException e)
            {
                throw new ChannelException("Failed to open a socket.", e);
            }
        }

        return newSocket(provider);
    }

    /**
     * Ensures the running Java version is at least 7.
     *
     * @throws UnsupportedOperationException if the reported Java version is
     *         lower than 7
     */
    private static void checkJavaVersion()
    {
        final boolean supported = PlatformDependent.javaVersion() >= 7;
        if (supported)
        {
            return;
        }
        throw new UnsupportedOperationException("Only supported on java 7+.");
    }

    /**
     * Create a new instance which will use the operating system's default
     * {@link InternetProtocolFamily}. The underlying socket is opened with
     * the default {@link SelectorProvider}.
     */
    public NioDatagramChannel()
    {
        this(newSocket(DEFAULT_SELECTOR_PROVIDER));
    }

    /**
     * Create a new instance using the given {@link SelectorProvider} which will
     * use the operating system's default {@link InternetProtocolFamily}.
     *
     * @param provider the provider used to open the underlying socket
     */
    public NioDatagramChannel(SelectorProvider provider)
    {
        this(newSocket(provider));
    }

    /**
     * Create a new instance using the given {@link InternetProtocolFamily}. If
     * {@code null} is used, the operating system's default will be chosen.
     * The underlying socket is opened with the default
     * {@link SelectorProvider}.
     *
     * @param ipFamily the protocol family to use, or {@code null}
     */
    public NioDatagramChannel(InternetProtocolFamily ipFamily)
    {
        this(newSocket(DEFAULT_SELECTOR_PROVIDER, ipFamily));
    }

    /**
     * Create a new instance using the given {@link SelectorProvider} and
     * {@link InternetProtocolFamily}. If {@link InternetProtocolFamily} is
     * {@code null}, the operating system's default will be chosen.
     *
     * @param provider the provider used to open the underlying socket
     * @param ipFamily the protocol family to use, or {@code null}
     */
    public NioDatagramChannel(SelectorProvider provider,
            InternetProtocolFamily ipFamily)
    {
        this(newSocket(provider, ipFamily));
    }

    /**
     * Create a new instance from the given {@link DatagramChannel}.
     *
     * @param socket the JDK channel to wrap; it is also handed to the
     *        {@link NioDatagramChannelConfig} created for this instance
     */
    public NioDatagramChannel(DatagramChannel socket)
    {
        // NOTE(review): the null first argument is presumably the parent
        // channel (none for a datagram channel) - confirm against the superclass.
        super(null, socket, SelectionKey.OP_READ);
        config = new NioDatagramChannelConfig(this, socket);
    }

    /**
     * Returns the {@link ChannelMetadata} constant shared by all instances of
     * this class.
     */
    @Override
    public ChannelMetadata metadata()
    {
        return METADATA;
    }

    /**
     * Returns the {@link DatagramChannelConfig} that was created for this
     * channel in the constructor.
     */
    @Override
    public DatagramChannelConfig config()
    {
        return config;
    }

    /**
     * Reports whether this channel is active.
     *
     * @return {@code true} when the underlying JDK channel is open and either
     *         {@link ChannelOption#DATAGRAM_CHANNEL_ACTIVE_ON_REGISTRATION}
     *         is enabled while the channel is registered, or the socket is
     *         bound
     */
    @Override
    @SuppressWarnings("deprecation")
    public boolean isActive()
    {
        final DatagramChannel jdkChannel = javaChannel();
        if (!jdkChannel.isOpen())
        {
            return false;
        }

        final boolean activeOnRegistration = config.getOption(
                ChannelOption.DATAGRAM_CHANNEL_ACTIVE_ON_REGISTRATION);
        if (activeOnRegistration && isRegistered())
        {
            return true;
        }
        return jdkChannel.socket().isBound();
    }

    /**
     * Reports whether the underlying JDK {@link DatagramChannel} is connected.
     */
    @Override
    public boolean isConnected()
    {
        final DatagramChannel jdkChannel = javaChannel();
        return jdkChannel.isConnected();
    }

    /**
     * Returns the channel held by the superclass, cast to
     * {@link DatagramChannel}.
     */
    @Override
    protected DatagramChannel javaChannel()
    {
        return (DatagramChannel) super.javaChannel();
    }

    /**
     * Returns the local socket address reported by the underlying JDK socket.
     */
    @Override
    protected SocketAddress localAddress0()
    {
        final DatagramChannel jdkChannel = javaChannel();
        return jdkChannel.socket().getLocalSocketAddress();
    }

    /**
     * Returns the remote socket address reported by the underlying JDK socket.
     */
    @Override
    protected SocketAddress remoteAddress0()
    {
        final DatagramChannel jdkChannel = javaChannel();
        return jdkChannel.socket().getRemoteSocketAddress();
    }

    /**
     * Binds the underlying socket to the given local address by delegating to
     * {@link #doBind0(SocketAddress)}.
     *
     * @param localAddress the address to bind to
     * @throws Exception if binding fails
     */
    @Override
    protected void doBind(SocketAddress localAddress) throws Exception
    {
        doBind0(localAddress);
    }

    /**
     * Binds the underlying channel to {@code localAddress}. On Java 7 and
     * later the bind goes through {@link SocketUtils}; on older versions the
     * channel's socket is bound directly.
     *
     * @param localAddress the address to bind to
     * @throws Exception if binding fails
     */
    private void doBind0(SocketAddress localAddress) throws Exception
    {
        if (PlatformDependent.javaVersion() < 7)
        {
            javaChannel().socket().bind(localAddress);
            return;
        }
        SocketUtils.bind(javaChannel(), localAddress);
    }

    /**
     * Connects the underlying channel to {@code remoteAddress}, binding to
     * {@code localAddress} first when one is supplied. If the connect attempt
     * throws, the channel is closed before the exception propagates.
     *
     * @param remoteAddress the peer to connect to
     * @param localAddress the local address to bind to, or {@code null}
     * @return always {@code true} when the connect call returns normally
     * @throws Exception if binding or connecting fails
     */
    @Override
    protected boolean doConnect(SocketAddress remoteAddress,
            SocketAddress localAddress) throws Exception
    {
        if (localAddress != null)
        {
            doBind0(localAddress);
        }

        boolean connected = false;
        try
        {
            javaChannel().connect(remoteAddress);
            connected = true;
        }
        finally
        {
            // A failed connect leaves nothing usable behind, so close.
            if (!connected)
            {
                doClose();
            }
        }
        return true;
    }

    /**
     * Always throws {@link Error}.
     *
     * NOTE(review): presumably this is never expected to be called because
     * {@code doConnect} returns {@code true} whenever it returns normally, so
     * there is no pending connect to finish - confirm against the superclass.
     */
    @Override
    protected void doFinishConnect() throws Exception
    {
        throw new Error();
    }

    /**
     * Disconnects the underlying JDK {@link DatagramChannel}.
     *
     * @throws Exception if the JDK channel fails to disconnect
     */
    @Override
    protected void doDisconnect() throws Exception
    {
        final DatagramChannel jdkChannel = javaChannel();
        jdkChannel.disconnect();
    }

    /**
     * Closes the underlying JDK {@link DatagramChannel}.
     *
     * @throws Exception if the JDK channel fails to close
     */
    @Override
    protected void doClose() throws Exception
    {
        final DatagramChannel jdkChannel = javaChannel();
        jdkChannel.close();
    }

    /**
     * Tries to receive a single datagram from the underlying channel.
     * <p>
     * A buffer is obtained from the receive allocator handle and handed to
     * {@link DatagramChannel#receive(ByteBuffer)}. When a datagram arrives it
     * is wrapped in a {@link DatagramPacket} and appended to {@code buf}.
     * The buffer is released in every case where no packet was appended.
     *
     * @param buf the list that receives the read packet
     * @return {@code 1} if a packet was added to {@code buf}, {@code 0} if the
     *         receive call reported no sender (returned {@code null})
     * @throws Exception any failure, rethrown through
     *         {@link PlatformDependent#throwException(Throwable)}
     */
    @Override
    protected int doReadMessages(List<Object> buf) throws Exception
    {
        final DatagramChannel jdkChannel = javaChannel();
        final DatagramChannelConfig cfg = config();
        final RecvByteBufAllocator.Handle handle = unsafe().recvBufAllocHandle();

        final ByteBuf data = handle.allocate(cfg.getAllocator());
        handle.attemptedBytesRead(data.writableBytes());

        // Flipped only after the packet holding the buffer is in 'buf'.
        boolean handedOff = false;
        try
        {
            final ByteBuffer target = data.internalNioBuffer(
                    data.writerIndex(), data.writableBytes());
            final int startPosition = target.position();
            final InetSocketAddress sender = (InetSocketAddress) jdkChannel
                    .receive(target);
            if (sender != null)
            {
                // The position delta is the number of bytes received.
                handle.lastBytesRead(target.position() - startPosition);
                final ByteBuf payload = data.writerIndex(
                        data.writerIndex() + handle.lastBytesRead());
                buf.add(new DatagramPacket(payload, localAddress(), sender));
                handedOff = true;
                return 1;
            }
            return 0;
        }
        catch (Throwable cause)
        {
            PlatformDependent.throwException(cause);
            // NOTE(review): presumably unreachable because the call above
            // always throws; kept so the method returns on every path.
            return -1;
        }
        finally
        {
            if (!handedOff)
            {
                data.release();
            }
        }
    }

    /**
     * Writes one outbound message to the underlying channel.
     * <p>
     * An {@link AddressedEnvelope} is sent to its recipient; a bare
     * {@link ByteBuf} (or an envelope without recipient) is written to the
     * connected peer.
     *
     * @param msg the message, either an {@link AddressedEnvelope} carrying a
     *        {@link ByteBuf} or a {@link ByteBuf}
     * @param in the outbound buffer (not used by this implementation)
     * @return {@code true} if the payload was empty or at least one byte was
     *         written, {@code false} otherwise
     * @throws Exception if the underlying send or write fails
     */
    @Override
    protected boolean doWriteMessage(Object msg, ChannelOutboundBuffer in)
            throws Exception
    {
        final SocketAddress target;
        final ByteBuf payload;
        if (!(msg instanceof AddressedEnvelope))
        {
            payload = (ByteBuf) msg;
            target = null;
        }
        else
        {
            @SuppressWarnings("unchecked")
            AddressedEnvelope<ByteBuf, SocketAddress> envelope = (AddressedEnvelope<ByteBuf, SocketAddress>) msg;
            target = envelope.recipient();
            payload = envelope.content();
        }

        final int length = payload.readableBytes();
        if (length == 0)
        {
            // Nothing to send; report the message as written.
            return true;
        }

        final ByteBuffer nioPayload = payload.internalNioBuffer(
                payload.readerIndex(), length);
        final int sent = target == null
                ? javaChannel().write(nioPayload)
                : javaChannel().send(nioPayload, target);
        return sent > 0;
    }

    /**
     * Validates an outbound message and, where needed, replaces its payload
     * with the result of {@code newDirectBuffer(...)}.
     * <p>
     * Accepted messages are {@link DatagramPacket}, {@link ByteBuf} and
     * {@link AddressedEnvelope} whose content is a {@link ByteBuf}. A message
     * whose payload already satisfies
     * {@link #isSingleDirectBuffer(ByteBuf)} is returned unchanged.
     *
     * @param msg the outbound message
     * @return the message to write
     * @throws UnsupportedOperationException if the message type is not one of
     *         the accepted types
     */
    @Override
    protected Object filterOutboundMessage(Object msg)
    {
        if (msg instanceof DatagramPacket)
        {
            final DatagramPacket packet = (DatagramPacket) msg;
            final ByteBuf payload = packet.content();
            return isSingleDirectBuffer(payload)
                    ? packet
                    : new DatagramPacket(newDirectBuffer(packet, payload),
                            packet.recipient());
        }

        if (msg instanceof ByteBuf)
        {
            final ByteBuf payload = (ByteBuf) msg;
            return isSingleDirectBuffer(payload)
                    ? payload
                    : newDirectBuffer(payload);
        }

        if (msg instanceof AddressedEnvelope)
        {
            @SuppressWarnings("unchecked")
            AddressedEnvelope<Object, SocketAddress> envelope = (AddressedEnvelope<Object, SocketAddress>) msg;
            if (envelope.content() instanceof ByteBuf)
            {
                final ByteBuf payload = (ByteBuf) envelope.content();
                if (!isSingleDirectBuffer(payload))
                {
                    return new DefaultAddressedEnvelope<ByteBuf, SocketAddress>(
                            newDirectBuffer(envelope, payload),
                            envelope.recipient());
                }
                return envelope;
            }
        }

        throw new UnsupportedOperationException("unsupported message type: "
                + StringUtil.simpleClassName(msg) + EXPECTED_TYPES);
    }

    /**
     * Tells whether the given buffer is direct and backed by exactly one NIO
     * buffer. Buffers that fail this check have to be turned into a
     * non-composite buffer before they are written.
     *
     * @param buf the buffer to inspect
     * @return {@code true} if {@code buf} is direct and reports an NIO buffer
     *         count of one
     */
    private static boolean isSingleDirectBuffer(ByteBuf buf)
    {
        if (!buf.isDirect())
        {
            return false;
        }
        return buf.nioBufferCount() == 1;
    }

    /**
     * Always returns {@code true}: a failed write does not stop the remaining
     * writes.
     */
    @Override
    protected boolean continueOnWriteError()
    {
        // Continue on write error as a DatagramChannel can write to multiple
        // remote peers
        //
        // See https://github.com/netty/netty/issues/2665
        return true;
    }

    /**
     * Returns the local address of this channel as an
     * {@link InetSocketAddress}.
     */
    @Override
    public InetSocketAddress localAddress()
    {
        final SocketAddress address = super.localAddress();
        return (InetSocketAddress) address;
    }

    /**
     * Returns the remote address of this channel as an
     * {@link InetSocketAddress}.
     */
    @Override
    public InetSocketAddress remoteAddress()
    {
        final SocketAddress address = super.remoteAddress();
        return (InetSocketAddress) address;
    }

    /**
     * Joins the given multicast group, reporting the outcome through a newly
     * created promise.
     *
     * @param multicastAddress the group to join
     * @return the future that is notified once the operation completes
     */
    @Override
    public ChannelFuture joinGroup(InetAddress multicastAddress)
    {
        final ChannelPromise promise = newPromise();
        return joinGroup(multicastAddress, promise);
    }

    /**
     * Joins the given multicast group on the network interface that is looked
     * up from this channel's local address.
     *
     * @param multicastAddress the group to join
     * @param promise the promise to notify
     * @return {@code promise}; it is failed if the interface lookup throws a
     *         {@link SocketException}
     */
    @Override
    public ChannelFuture joinGroup(InetAddress multicastAddress,
            ChannelPromise promise)
    {
        final NetworkInterface localInterface;
        try
        {
            localInterface = NetworkInterface
                    .getByInetAddress(localAddress().getAddress());
        }
        catch (SocketException e)
        {
            promise.setFailure(e);
            return promise;
        }
        return joinGroup(multicastAddress, localInterface, null, promise);
    }

    /**
     * Joins the given multicast group on the given interface, reporting the
     * outcome through a newly created promise.
     *
     * @param multicastAddress the group to join
     * @param networkInterface the interface to join on
     * @return the future that is notified once the operation completes
     */
    @Override
    public ChannelFuture joinGroup(InetSocketAddress multicastAddress,
            NetworkInterface networkInterface)
    {
        final ChannelPromise promise = newPromise();
        return joinGroup(multicastAddress, networkInterface, promise);
    }

    /**
     * Joins the multicast group identified by the address part of
     * {@code multicastAddress} on the given interface, without a source
     * filter.
     *
     * @param multicastAddress the group to join
     * @param networkInterface the interface to join on
     * @param promise the promise to notify
     * @return {@code promise}
     */
    @Override
    public ChannelFuture joinGroup(InetSocketAddress multicastAddress,
            NetworkInterface networkInterface, ChannelPromise promise)
    {
        final InetAddress group = multicastAddress.getAddress();
        return joinGroup(group, networkInterface, null, promise);
    }

    /**
     * Joins the given multicast group on the given interface for the given
     * source, reporting the outcome through a newly created promise.
     *
     * @param multicastAddress the group to join
     * @param networkInterface the interface to join on
     * @param source the source address to accept, or {@code null} for any
     * @return the future that is notified once the operation completes
     */
    @Override
    public ChannelFuture joinGroup(InetAddress multicastAddress,
            NetworkInterface networkInterface, InetAddress source)
    {
        final ChannelPromise promise = newPromise();
        return joinGroup(multicastAddress, networkInterface, source, promise);
    }

    /**
     * Joins a multicast group on the underlying JDK channel and remembers the
     * resulting {@link MembershipKey} so it can later be dropped or used for
     * blocking a source.
     * <p>
     * Any {@link Throwable} raised while joining or recording the key is
     * reported through {@code promise} instead of being thrown.
     *
     * @param multicastAddress the group to join
     * @param networkInterface the interface to join on
     * @param source the source address to accept, or {@code null} to join
     *        without a source filter
     * @param promise the promise to notify
     * @return {@code promise}
     * @throws UnsupportedOperationException on Java versions older than 7
     * @throws NullPointerException if {@code multicastAddress} or
     *         {@code networkInterface} is {@code null}
     */
    @Override
    public ChannelFuture joinGroup(InetAddress multicastAddress,
            NetworkInterface networkInterface, InetAddress source,
            ChannelPromise promise)
    {
        checkJavaVersion();

        if (multicastAddress == null)
        {
            throw new NullPointerException("multicastAddress");
        }
        if (networkInterface == null)
        {
            throw new NullPointerException("networkInterface");
        }

        try
        {
            final DatagramChannel jdkChannel = javaChannel();
            final MembershipKey key = source == null
                    ? jdkChannel.join(multicastAddress, networkInterface)
                    : jdkChannel.join(multicastAddress, networkInterface,
                            source);

            synchronized (this)
            {
                // The map is created on the first successful join.
                if (memberships == null)
                {
                    memberships = new HashMap<InetAddress, List<MembershipKey>>();
                }

                List<MembershipKey> groupKeys = memberships
                        .get(multicastAddress);
                if (groupKeys == null)
                {
                    groupKeys = new ArrayList<MembershipKey>();
                    memberships.put(multicastAddress, groupKeys);
                }
                groupKeys.add(key);
            }

            promise.setSuccess();
        }
        catch (Throwable e)
        {
            promise.setFailure(e);
        }

        return promise;
    }

    /**
     * Leaves the given multicast group, reporting the outcome through a newly
     * created promise.
     *
     * @param multicastAddress the group to leave
     * @return the future that is notified once the operation completes
     */
    @Override
    public ChannelFuture leaveGroup(InetAddress multicastAddress)
    {
        final ChannelPromise promise = newPromise();
        return leaveGroup(multicastAddress, promise);
    }

    /**
     * Leaves the given multicast group on the network interface that is
     * looked up from this channel's local address.
     *
     * @param multicastAddress the group to leave
     * @param promise the promise to notify
     * @return {@code promise}; it is failed if the interface lookup throws a
     *         {@link SocketException}
     */
    @Override
    public ChannelFuture leaveGroup(InetAddress multicastAddress,
            ChannelPromise promise)
    {
        final NetworkInterface localInterface;
        try
        {
            localInterface = NetworkInterface
                    .getByInetAddress(localAddress().getAddress());
        }
        catch (SocketException e)
        {
            promise.setFailure(e);
            return promise;
        }
        return leaveGroup(multicastAddress, localInterface, null, promise);
    }

    /**
     * Leaves the given multicast group on the given interface, reporting the
     * outcome through a newly created promise.
     *
     * @param multicastAddress the group to leave
     * @param networkInterface the interface the group was joined on
     * @return the future that is notified once the operation completes
     */
    @Override
    public ChannelFuture leaveGroup(InetSocketAddress multicastAddress,
            NetworkInterface networkInterface)
    {
        final ChannelPromise promise = newPromise();
        return leaveGroup(multicastAddress, networkInterface, promise);
    }

    /**
     * Leaves the multicast group identified by the address part of
     * {@code multicastAddress} on the given interface, matching memberships
     * that have no source filter.
     *
     * @param multicastAddress the group to leave
     * @param networkInterface the interface the group was joined on
     * @param promise the promise to notify
     * @return {@code promise}
     */
    @Override
    public ChannelFuture leaveGroup(InetSocketAddress multicastAddress,
            NetworkInterface networkInterface, ChannelPromise promise)
    {
        final InetAddress group = multicastAddress.getAddress();
        return leaveGroup(group, networkInterface, null, promise);
    }

    /**
     * Leaves the given multicast group on the given interface for the given
     * source, reporting the outcome through a newly created promise.
     *
     * @param multicastAddress the group to leave
     * @param networkInterface the interface the group was joined on
     * @param source the source address of the membership, or {@code null}
     * @return the future that is notified once the operation completes
     */
    @Override
    public ChannelFuture leaveGroup(InetAddress multicastAddress,
            NetworkInterface networkInterface, InetAddress source)
    {
        final ChannelPromise promise = newPromise();
        return leaveGroup(multicastAddress, networkInterface, source, promise);
    }

    /**
     * Drops every recorded membership for {@code multicastAddress} that was
     * joined on {@code networkInterface} with a matching source address, and
     * forgets the group once no keys remain for it.
     * <p>
     * The promise is always completed successfully, including when no
     * matching membership is recorded.
     *
     * @param multicastAddress the group to leave
     * @param networkInterface the interface the group was joined on
     * @param source the source address of the membership to drop, or
     *        {@code null} to match memberships without a source address
     * @param promise the promise to notify
     * @return {@code promise}
     * @throws UnsupportedOperationException on Java versions older than 7
     * @throws NullPointerException if {@code multicastAddress} or
     *         {@code networkInterface} is {@code null}
     */
    @Override
    public ChannelFuture leaveGroup(InetAddress multicastAddress,
            NetworkInterface networkInterface, InetAddress source,
            ChannelPromise promise)
    {
        checkJavaVersion();

        if (multicastAddress == null)
        {
            throw new NullPointerException("multicastAddress");
        }
        if (networkInterface == null)
        {
            throw new NullPointerException("networkInterface");
        }

        synchronized (this)
        {
            final List<MembershipKey> groupKeys = memberships == null
                    ? null
                    : memberships.get(multicastAddress);
            if (groupKeys != null)
            {
                for (Iterator<MembershipKey> it = groupKeys.iterator(); it
                        .hasNext();)
                {
                    final MembershipKey key = it.next();
                    if (!networkInterface.equals(key.networkInterface()))
                    {
                        continue;
                    }

                    // A null source only matches keys without a source address.
                    final boolean sameSource = source == null
                            ? key.sourceAddress() == null
                            : source.equals(key.sourceAddress());
                    if (sameSource)
                    {
                        key.drop();
                        it.remove();
                    }
                }

                if (groupKeys.isEmpty())
                {
                    memberships.remove(multicastAddress);
                }
            }
        }

        promise.setSuccess();
        return promise;
    }

    /**
     * Blocks the given {@code sourceToBlock} address for the given
     * {@code multicastAddress} on the given {@code networkInterface},
     * reporting the outcome through a newly created promise.
     *
     * @param multicastAddress the group the source is blocked for
     * @param networkInterface the interface the group was joined on
     * @param sourceToBlock the source address to block
     * @return the future that is notified once the operation completes
     */
    @Override
    public ChannelFuture block(InetAddress multicastAddress,
            NetworkInterface networkInterface, InetAddress sourceToBlock)
    {
        final ChannelPromise promise = newPromise();
        return block(multicastAddress, networkInterface, sourceToBlock,
                promise);
    }

    /**
     * Blocks the given {@code sourceToBlock} address for the given
     * {@code multicastAddress} on the given {@code networkInterface}.
     * <p>
     * Every recorded membership of the group on that interface has the source
     * blocked. If no membership is recorded for the group, nothing is blocked
     * and the promise succeeds. The first {@link IOException} raised while
     * blocking fails the promise and stops processing, so the promise is
     * completed exactly once.
     *
     * @param multicastAddress the group the source is blocked for
     * @param networkInterface the interface the group was joined on
     * @param sourceToBlock the source address to block
     * @param promise the promise to notify
     * @return {@code promise}
     * @throws UnsupportedOperationException on Java versions older than 7
     * @throws NullPointerException if {@code multicastAddress},
     *         {@code sourceToBlock} or {@code networkInterface} is
     *         {@code null}
     */
    @Override
    public ChannelFuture block(InetAddress multicastAddress,
            NetworkInterface networkInterface, InetAddress sourceToBlock,
            ChannelPromise promise)
    {
        checkJavaVersion();

        if (multicastAddress == null)
        {
            throw new NullPointerException("multicastAddress");
        }
        if (sourceToBlock == null)
        {
            throw new NullPointerException("sourceToBlock");
        }

        if (networkInterface == null)
        {
            throw new NullPointerException("networkInterface");
        }
        synchronized (this)
        {
            if (memberships != null)
            {
                List<MembershipKey> keys = memberships.get(multicastAddress);
                // The group may never have been joined (or was already
                // left), in which case there is no key list to iterate.
                if (keys != null)
                {
                    for (MembershipKey key : keys)
                    {
                        if (networkInterface.equals(key.networkInterface()))
                        {
                            try
                            {
                                key.block(sourceToBlock);
                            }
                            catch (IOException e)
                            {
                                // Return right away: falling through to
                                // setSuccess() would complete the promise a
                                // second time.
                                promise.setFailure(e);
                                return promise;
                            }
                        }
                    }
                }
            }
        }
        promise.setSuccess();
        return promise;
    }

    /**
     * Blocks the given {@code sourceToBlock} address for the given
     * {@code multicastAddress}, reporting the outcome through a newly created
     * promise.
     *
     * @param multicastAddress the group the source is blocked for
     * @param sourceToBlock the source address to block
     * @return the future that is notified once the operation completes
     */
    @Override
    public ChannelFuture block(InetAddress multicastAddress,
            InetAddress sourceToBlock)
    {
        final ChannelPromise promise = newPromise();
        return block(multicastAddress, sourceToBlock, promise);
    }

    /**
     * Blocks the given {@code sourceToBlock} address for the given
     * {@code multicastAddress} on the network interface that is looked up
     * from this channel's local address.
     *
     * @param multicastAddress the group the source is blocked for
     * @param sourceToBlock the source address to block
     * @param promise the promise to notify
     * @return {@code promise}; it is failed if the interface lookup throws a
     *         {@link SocketException}
     */
    @Override
    public ChannelFuture block(InetAddress multicastAddress,
            InetAddress sourceToBlock, ChannelPromise promise)
    {
        final NetworkInterface localInterface;
        try
        {
            localInterface = NetworkInterface
                    .getByInetAddress(localAddress().getAddress());
        }
        catch (SocketException e)
        {
            promise.setFailure(e);
            return promise;
        }
        return block(multicastAddress, localInterface, sourceToBlock, promise);
    }

    /**
     * Delegates unchanged to the superclass implementation.
     *
     * NOTE(review): presumably redeclared here only so that classes in this
     * package can call the protected method - confirm against callers.
     *
     * @param readPending the value passed on to the superclass
     */
    @Override
    @Deprecated
    protected void setReadPending(boolean readPending)
    {
        super.setReadPending(readPending);
    }

    /**
     * Package-private bridge that simply invokes {@code clearReadPending()}.
     */
    void clearReadPending0()
    {
        clearReadPending();
    }

    /**
     * Decides whether a read failure should close the channel.
     *
     * @param cause the failure raised while reading
     * @return {@code false} for any {@link SocketException}; otherwise
     *         whatever the superclass decides
     */
    @Override
    protected boolean closeOnReadError(Throwable cause)
    {
        // A SocketException on a DatagramChannel usually does not prevent
        // further receives, so the channel is kept open for it.
        // See https://github.com/netty/netty/issues/5893
        return !(cause instanceof SocketException)
                && super.closeOnReadError(cause);
    }
}
